{
 "cells": [
  {
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2025-01-15T11:23:18.821336Z",
     "start_time": "2025-01-15T11:22:59.885929Z"
    }
   },
   "cell_type": "code",
   "source": [
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "# Single source of truth for the checkpoint, so the model and tokenizer always come from the same repo.\n",
    "MODEL_ID = \"mistralai/Mistral-7B-Instruct-v0.1\"\n",
    "\n",
    "# Preferred accelerator name. Kept for reference only: the load below passes device_map=\"auto\"\n",
    "# instead of this value, so `device` does not influence where the weights end up.\n",
    "device = \"mps\"\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map=\"auto\")\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)"
   ],
   "id": "initial_id",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ],
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "c62421232c9645e8811a5401278410e1"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some parameters are on the meta device because they were offloaded to the disk.\n"
     ]
    }
   ],
   "execution_count": 1
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-15T11:23:18.870246Z",
     "start_time": "2025-01-15T11:23:18.861544Z"
    }
   },
   "cell_type": "code",
   "source": [
    "from transformers.models.mistral.modeling_mistral import MistralAttention\n",
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "import types\n",
    "from typing import Callable, List, Optional, Tuple, Union\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, Cache\n",
    "\n",
    "\n",
    "def rotate_half(x):\n",
    "    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n",
    "    x1 = x[..., : x.shape[-1] // 2]\n",
    "    x2 = x[..., x.shape[-1] // 2 :]\n",
    "    return torch.cat((-x2, x1), dim=-1)\n",
    "\n",
    "\n",
    "def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):\n",
    "    \"\"\"Rotate the query and key tensors with precomputed rotary cos/sin tables.\n",
    "\n",
    "    Args:\n",
    "        q (`torch.Tensor`): Query tensor.\n",
    "        k (`torch.Tensor`): Key tensor.\n",
    "        cos (`torch.Tensor`): Cosine table of the rotary embedding.\n",
    "        sin (`torch.Tensor`): Sine table of the rotary embedding.\n",
    "        position_ids (`torch.Tensor`, *optional*):\n",
    "            Accepted for signature compatibility; never read by this function.\n",
    "        unsqueeze_dim (`int`, *optional*, defaults to 1):\n",
    "            Axis at which a singleton dimension is inserted into `cos` and `sin` so they broadcast\n",
    "            against `q` and `k`. With `cos`/`sin` shaped [batch_size, seq_len, head_dim], use 1 when\n",
    "            `q`/`k` are [batch_size, heads, seq_len, head_dim] and 2 when they are\n",
    "            [batch_size, seq_len, heads, head_dim].\n",
    "\n",
    "    Returns:\n",
    "        `tuple(torch.Tensor)`: the rotated query and the rotated key, in that order.\n",
    "    \"\"\"\n",
    "    cos_b = cos.unsqueeze(unsqueeze_dim)\n",
    "    sin_b = sin.unsqueeze(unsqueeze_dim)\n",
    "\n",
    "    def _rotate(states):\n",
    "        # Same formula for queries and keys: x * cos + rotate_half(x) * sin.\n",
    "        return (states * cos_b) + (rotate_half(states) * sin_b)\n",
    "\n",
    "    return _rotate(q), _rotate(k)\n",
    "\n",
    "\n",
    "def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,\n",
    "    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n",
    "    \"\"\"\n",
    "    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n",
    "    if n_rep == 1:\n",
    "        return hidden_states\n",
    "    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)\n",
    "    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n",
    "\n",
    "def eager_attention_forward(\n",
    "    module: nn.Module,\n",
    "    query: torch.Tensor,\n",
    "    key: torch.Tensor,\n",
    "    value: torch.Tensor,\n",
    "    attention_mask: Optional[torch.Tensor],\n",
    "    scaling: float,\n",
    "    dropout: float = 0.0,\n",
    "    **kwargs,\n",
    "):\n",
    "    \"\"\"Scaled dot-product attention written out with explicit matmul and softmax calls.\n",
    "\n",
    "    Args:\n",
    "        module: Attention module; its `num_key_value_groups` and `training` attributes are read.\n",
    "        query: Query states.\n",
    "        key: Key states; heads are repeated `module.num_key_value_groups` times via `repeat_kv`.\n",
    "        value: Value states; repeated the same way as `key`.\n",
    "        attention_mask: Optional mask added to the scaled scores after being sliced along its last\n",
    "            axis to the key length.\n",
    "        scaling: Factor multiplied into the raw query-key scores.\n",
    "        dropout: Dropout probability applied to the attention weights while `module.training` is true.\n",
    "        **kwargs: Accepted and ignored.\n",
    "\n",
    "    Returns:\n",
    "        `(attn_output, attn_weights)`, where `attn_output` has axes 1 and 2 swapped and is contiguous.\n",
    "    \"\"\"\n",
    "    groups = module.num_key_value_groups\n",
    "    expanded_keys = repeat_kv(key, groups)\n",
    "    expanded_values = repeat_kv(value, groups)\n",
    "\n",
    "    scores = (query @ expanded_keys.transpose(2, 3)) * scaling\n",
    "    if attention_mask is not None:\n",
    "        scores = scores + attention_mask[:, :, :, : expanded_keys.shape[-2]]\n",
    "\n",
    "    # Softmax is computed in float32 and cast back to the query dtype afterwards.\n",
    "    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)\n",
    "    probs = nn.functional.dropout(probs, p=dropout, training=module.training)\n",
    "\n",
    "    context = (probs @ expanded_values).transpose(1, 2).contiguous()\n",
    "    return context, probs\n",
    "\n",
    "def custom_attention_forward(\n",
    "    self,\n",
    "    hidden_states: torch.Tensor,\n",
    "    position_embeddings: Tuple[torch.Tensor, torch.Tensor],\n",
    "    attention_mask: Optional[torch.Tensor],\n",
    "    past_key_value: Optional[Cache] = None,\n",
    "    cache_position: Optional[torch.LongTensor] = None,\n",
    "    **kwargs,\n",
    ") -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n",
    "    \"\"\"Replacement attention `forward` that always goes through `eager_attention_forward`.\n",
    "\n",
    "    Projects the hidden states to per-head query/key/value tensors, applies the rotary embedding to\n",
    "    queries and keys, updates the key/value cache when one is supplied, runs eager attention and\n",
    "    finally applies the output projection.\n",
    "\n",
    "    Args:\n",
    "        hidden_states: Input activations; the last axis is the one consumed by the projections.\n",
    "        position_embeddings: `(cos, sin)` pair for the rotary embedding.\n",
    "        attention_mask: Optional mask forwarded to `eager_attention_forward`.\n",
    "        past_key_value: Optional cache; its `update` method is called with this layer's index.\n",
    "        cache_position: Forwarded to the cache update.\n",
    "        **kwargs: Forwarded unchanged to `eager_attention_forward`.\n",
    "\n",
    "    Returns:\n",
    "        `(attn_output, attn_weights)`.\n",
    "    \"\"\"\n",
    "    leading_dims = hidden_states.shape[:-1]\n",
    "    per_head_shape = (*leading_dims, -1, self.head_dim)\n",
    "\n",
    "    def _split_heads(projection):\n",
    "        # Project, unfold the feature axis into (heads, head_dim), then swap axes 1 and 2.\n",
    "        return projection(hidden_states).view(per_head_shape).transpose(1, 2)\n",
    "\n",
    "    query_states = _split_heads(self.q_proj)\n",
    "    key_states = _split_heads(self.k_proj)\n",
    "    value_states = _split_heads(self.v_proj)\n",
    "\n",
    "    cos, sin = position_embeddings\n",
    "    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)\n",
    "\n",
    "    if past_key_value is not None:\n",
    "        # sin and cos are specific to RoPE models; cache_position needed for the static cache\n",
    "        key_states, value_states = past_key_value.update(\n",
    "            key_states,\n",
    "            value_states,\n",
    "            self.layer_idx,\n",
    "            {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position},\n",
    "        )\n",
    "\n",
    "    # Dropout is only active while the module is in training mode.\n",
    "    dropout_p = self.attention_dropout if self.training else 0.0\n",
    "    attn_output, attn_weights = eager_attention_forward(\n",
    "        self,\n",
    "        query_states,\n",
    "        key_states,\n",
    "        value_states,\n",
    "        attention_mask,\n",
    "        dropout=dropout_p,\n",
    "        scaling=self.scaling,\n",
    "        sliding_window=getattr(self.config, \"sliding_window\", None),  # main diff with Llama\n",
    "        **kwargs,\n",
    "    )\n",
    "\n",
    "    merged = attn_output.reshape(*leading_dims, -1).contiguous()\n",
    "    return self.o_proj(merged), attn_weights\n",
    "\n",
    "\n",
    "# # 2. Define the function that patches the model's attention (older recursive version, superseded by the definition below)\n",
    "# def enable_mistral_custom_attention(model):\n",
    "#     for name, module in reversed(model._modules.items()):\n",
    "#         if len(list(module.children())) > 0:\n",
    "#             enable_mistral_custom_attention(\n",
    "#                 module,\n",
    "#             )\n",
    "#\n",
    "#         if isinstance(module, MistralAttention):\n",
    "#             model._modules[name].forward = types.MethodType(\n",
    "#                 custom_attention_forward, model._modules[name]\n",
    "#             )\n",
    "\n",
    "def enable_mistral_custom_attention(model):\n",
    "    \"\"\"Rebind `forward` on every `MistralAttention` submodule to `custom_attention_forward`.\n",
    "\n",
    "    The model is patched in place; nothing is returned.\n",
    "    \"\"\"\n",
    "    # modules() walks the module tree recursively, so no manual recursion is needed.\n",
    "    attention_layers = (sub for sub in model.modules() if isinstance(sub, MistralAttention))\n",
    "    for layer in attention_layers:\n",
    "        layer.forward = types.MethodType(custom_attention_forward, layer)"
   ],
   "id": "77f2b6b39b402e0d",
   "outputs": [],
   "execution_count": 2
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-15T11:23:18.875926Z",
     "start_time": "2025-01-15T11:23:18.873834Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Patch the model's attention implementation\n",
    "enable_mistral_custom_attention(model)\n",
    "# model = model.to(device)"
   ],
   "id": "dd20e800a2f70ff9",
   "outputs": [],
   "execution_count": 3
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-15T11:24:48.146632Z",
     "start_time": "2025-01-15T11:24:19.869904Z"
    }
   },
   "cell_type": "code",
   "source": [
    "messages = [\n",
    "    {\"role\": \"user\", \"content\": \"What is your favourite condiment?\"},\n",
    "    {\"role\": \"assistant\", \"content\": \"Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!\"},\n",
    "    {\"role\": \"user\", \"content\": \"Do you have mayonnaise recipes?\"}\n",
    "]\n",
    "\n",
    "# return_dict=True makes the tokenizer hand back the attention mask together with the input ids,\n",
    "# so generate() no longer has to guess it (it previously warned that the mask was not set).\n",
    "chat_inputs = tokenizer.apply_chat_template(messages, return_tensors=\"pt\", return_dict=True)\n",
    "encodeds = chat_inputs[\"input_ids\"]\n",
    "\n",
    "model_inputs = encodeds\n",
    "# model.to(device)\n",
    "\n",
    "generated_ids = model.generate(\n",
    "    model_inputs,\n",
    "    attention_mask=chat_inputs[\"attention_mask\"],\n",
    "    # Same value generate() previously fell back to on its own; passing it explicitly silences the warning.\n",
    "    pad_token_id=tokenizer.eos_token_id,\n",
    "    max_new_tokens=10,\n",
    "    do_sample=True,\n",
    ")\n",
    "decoded = tokenizer.batch_decode(generated_ids)\n",
    "print(decoded[0])\n"
   ],
   "id": "5f86935bfe05302c",
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
      "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<s> [INST] What is your favourite condiment? [/INST] Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!</s> [INST] Do you have mayonnaise recipes? [/INST] Of course! Here's a simple recipe for\n"
     ]
    }
   ],
   "execution_count": 5
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-15T11:23:59.720770Z",
     "start_time": "2025-01-15T11:23:59.718886Z"
    }
   },
   "cell_type": "code",
   "source": "",
   "id": "2b6c72c0be21c5f4",
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
